package com.my.thread.transaction.service;

import com.my.db.test.mybatis.Employee;
import com.my.db.test.mybatis.mapper.EmployeeMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

/**
 * Updates employee rows in parallel worker threads while trying to keep the whole batch
 * all-or-nothing: the data is split into {@link #THREAD_COUNT} partitions, each partition is
 * handed to {@code TestTranService.updateStudentsTransaction}, and once every worker has
 * reported back the collected {@link TransactionStatus} objects are either all committed or
 * all rolled back.
 *
 * <p>Because no transaction is completed until every worker has finished, each worker keeps
 * its JDBC connection for the full duration of the run. If the number of worker threads
 * exceeds the connection pool's maximum size, connection-acquisition timeouts will occur, so
 * the thread count must stay small.
 *
 * <p>NOTE(review): the statuses are completed on the calling thread, not on the worker thread
 * that presumably created them inside {@code TestTranService}. Spring's thread-bound
 * transaction managers associate resources with the thread that began the transaction, so
 * whether cross-thread commit/rollback works depends on how {@code TestTranService} and the
 * configured {@link PlatformTransactionManager} handle this — confirm before relying on it.
 */
@Service
public class TransactionThread {

    /** Number of worker threads, and therefore the number of data partitions. */
    private static final int THREAD_COUNT = 2;

    /** How long to wait for all workers before the run is treated as failed. */
    private static final long AWAIT_TIMEOUT_SECONDS = 30;

    @Autowired
    private EmployeeMapper employeeMapper;
    @Autowired
    private TestTranService testTranService;

    @Autowired
    private PlatformTransactionManager transactionManager;

    /**
     * Loads every employee, updates the partitions concurrently and then commits or rolls back
     * all worker transactions together.
     *
     * <p>A failure in any worker, a timeout while waiting for the workers, or an interruption
     * of the calling thread causes every collected transaction to be rolled back. Such
     * failures are only printed; the method still returns normally. If the calling thread was
     * interrupted, its interrupt status is restored before returning.
     *
     * <p>NOTE(review): callers cannot currently distinguish a rolled-back run from a committed
     * one — consider surfacing the failure if callers need to know.
     *
     * @throws InterruptedException declared for backward compatibility; an interruption while
     *                              waiting is handled internally (rollback + interrupt status
     *                              restored) rather than propagated
     * @throws RuntimeException     the first failure raised by the transaction manager while
     *                              committing or rolling back, with later failures attached as
     *                              suppressed exceptions
     */
    @Transactional(propagation = Propagation.REQUIRED,
            rollbackFor = {Exception.class})
    public void updateStudentWithThreadsAndTrans() throws InterruptedException {
        // Load all rows to be processed.
        List<Employee> employees = employeeMapper.selectList(null);
        int total = employees.size();

        // Rows per worker, rounded up so that the last partition takes the remainder.
        int chunkSize = (total + THREAD_COUNT - 1) / THREAD_COUNT;

        // Per-invocation collection of worker transaction statuses. This must NOT be an
        // instance field: the service is shared, so a field would leak already-completed
        // statuses into the next call and mix statuses of concurrent calls.
        List<TransactionStatus> statuses =
                Collections.synchronizedList(new ArrayList<TransactionStatus>());

        ExecutorService studentThreadPool = Executors.newFixedThreadPool(THREAD_COUNT);
        CountDownLatch threadLatchs = new CountDownLatch(THREAD_COUNT);
        AtomicBoolean isError = new AtomicBoolean(false);
        boolean interrupted = false;
        try {
            for (int i = 0; i < THREAD_COUNT; i++) {
                // Copy this worker's slice so it owns an independent list.
                int from = Math.min(i * chunkSize, total);
                int to = Math.min(from + chunkSize, total);
                List<Employee> threadDatas = new ArrayList<>(employees.subList(from, to));

                studentThreadPool.execute(() -> {
                    try {
                        // Presumably opens a transaction, registers its status in 'statuses'
                        // and applies the updates — see TestTranService.
                        testTranService.updateStudentsTransaction(transactionManager, statuses, threadDatas);
                    } catch (Throwable e) {
                        // Any worker failure marks the whole run as failed.
                        e.printStackTrace();
                        isError.set(true);
                    } finally {
                        threadLatchs.countDown();
                    }
                });
            }

            // Wait for all workers; a timeout counts as a failure.
            if (!threadLatchs.await(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                isError.set(true);
            }
        } catch (InterruptedException e) {
            // Remember the interruption; the flag is restored only after the transactions
            // have been completed so that the rollback below is not disturbed by it.
            interrupted = true;
            e.printStackTrace();
            isError.set(true);
        } catch (Throwable e) {
            e.printStackTrace();
            isError.set(true);
        } finally {
            // Always release the pool threads. On failure or timeout workers may still be
            // running, so ask them to stop instead of letting them continue in the background.
            if (isError.get()) {
                studentThreadPool.shutdownNow();
            } else {
                studentThreadPool.shutdown();
            }
        }

        try {
            // Work on a snapshot so that a straggling worker cannot modify the list mid-iteration.
            completeAll(new ArrayList<>(statuses), !isError.get());
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
        System.out.println("主线程完成");
    }

    /**
     * Commits or rolls back every not-yet-completed status, continuing past individual failures
     * so that no transaction is left dangling.
     *
     * <p>Once a commit fails, the remaining statuses are rolled back instead of committed.
     *
     * @param statuses the transaction statuses to complete
     * @param commit   {@code true} to commit, {@code false} to roll back
     * @throws RuntimeException the first failure encountered; subsequent failures are added to
     *                          it as suppressed exceptions
     */
    private void completeAll(List<TransactionStatus> statuses, boolean commit) {
        RuntimeException firstFailure = null;
        boolean committing = commit;
        for (TransactionStatus status : statuses) {
            if (status.isCompleted()) {
                // Completing a transaction twice is an error; skip finished ones.
                continue;
            }
            try {
                if (committing) {
                    transactionManager.commit(status);
                } else {
                    transactionManager.rollback(status);
                }
            } catch (RuntimeException e) {
                committing = false;
                if (firstFailure == null) {
                    firstFailure = e;
                } else {
                    firstFailure.addSuppressed(e);
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

}
